#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import copy
import logging
from io import BytesIO

import numpy as np
import torch
from d2go.utils.helper import retryable
from detectron2.data import detection_utils as utils, transforms as T
from detectron2.data.transforms.augmentation import (
    AugInput,
    AugmentationList,
)
from PIL import Image

from .build import D2GO_DATA_MAPPER_REGISTRY

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Keys looked up in a dataset dict for pre-fetched data. When a value is present
# under one of these keys, it is decoded by `read_image_with_prefetch` instead of
# reading the image / semantic segmentation file by its file name.
PREFETCHED_FILE_NAME = "prefetch_image"
PREFETCHED_SEM_SEG_FILE_NAME = "prefetch_sem_seg"


def read_image_with_prefetch(file_name, format=None, prefetched=None):
    """
    Load an image as a numpy array, preferring pre-fetched data when given.

    Args:
        file_name: forwarded to `utils.read_image` when nothing was pre-fetched.
        format: image format forwarded to detectron2's reading/conversion helpers.
        prefetched: optional object exposing `.numpy()`; the resulting array is
            wrapped in a `BytesIO` and opened with PIL.
            (presumably a tensor holding the encoded file bytes -- TODO confirm)

    Returns:
        The result of `utils.read_image` or `utils.convert_PIL_to_numpy`.
    """
    if prefetched is not None:
        encoded = BytesIO(prefetched.numpy().view())
        pil_image = Image.open(encoded)
        # work around this bug: https://github.com/python-pillow/Pillow/issues/3973
        pil_image = utils._apply_exif_orientation(pil_image)
        return utils.convert_PIL_to_numpy(pil_image, format)

    # Nothing was pre-fetched: let detectron2 read the file itself.
    return utils.read_image(file_name, format)


@D2GO_DATA_MAPPER_REGISTRY.register()
class D2GoDatasetMapper:
    """
    Callable that turns one dataset dict into the dict consumed downstream.

    The mapper reads the image (optionally through a custom `image_loader` or
    from pre-fetched data), applies the configured augmentations to the image
    and, in training, to the annotations / semantic segmentation, and stores
    the results back into a deep copy of the input dict.

    When `cfg.D2GO_DATA.MAPPER.CATCH_EXCEPTION` is set, failures are counted
    and `None` is returned instead of raising.
    """

    def __init__(self, cfg, is_train=True, image_loader=None, tfm_gens=None):
        """
        Args:
            cfg: config node; reads `INPUT.*`, `MODEL.*`, `DATASETS.*`,
                `DATALOADER.ASPECT_RATIO_GROUPING` and `D2GO_DATA.MAPPER.*`.
            is_train (bool): whether the mapper is used for training.
            image_loader: optional callable taking a dataset dict, which also
                provides `support(dataset_dict)`. The returned image must expose
                `.width` / `.height` and is converted with
                `utils.convert_PIL_to_numpy`.
            tfm_gens (list): optional augmentations; built from `cfg` with
                `utils.build_transform_gen` when None.
        """
        if tfm_gens is None:
            tfm_gens = utils.build_transform_gen(cfg, is_train)
        self.tfm_gens = tfm_gens

        if cfg.INPUT.CROP.ENABLED and is_train:
            self.crop_gen = T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)
            # D2GO NOTE: when INPUT.CROP.ENABLED, don't allow using RandomCropOp
            assert all(not isinstance(gen, T.RandomCrop) for gen in self.tfm_gens)
        else:
            self.crop_gen = None

        # fmt: off
        self.img_format     = cfg.INPUT.FORMAT  # noqa
        self.mask_on        = cfg.MODEL.MASK_ON  # noqa
        self.mask_format    = cfg.INPUT.MASK_FORMAT  # noqa
        self.keypoint_on    = cfg.MODEL.KEYPOINT_ON  # noqa
        # fmt: on
        if self.keypoint_on and is_train:
            # Flip only makes sense in training
            self.keypoint_hflip_indices = utils.create_keypoint_hflip_indices(
                cfg.DATASETS.TRAIN
            )
        else:
            self.keypoint_hflip_indices = None

        self.load_proposals = cfg.MODEL.LOAD_PROPOSALS
        if self.load_proposals:
            self.proposal_min_box_size = cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE
            self.proposal_topk = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )

        self.is_train = is_train

        # Setup image loader:
        self.image_loader = image_loader
        # When truthy, width/height are written into the dataset dict from the
        # loaded image and the image-size check is skipped.
        self.backfill_size = cfg.D2GO_DATA.MAPPER.BACKFILL_SIZE
        self.retry = cfg.D2GO_DATA.MAPPER.RETRY
        self.catch_exception = cfg.D2GO_DATA.MAPPER.CATCH_EXCEPTION

        if self.backfill_size and cfg.DATALOADER.ASPECT_RATIO_GROUPING:
            logger.warning(
                "ASPECT_RATIO_GROUPING may not work if image's width & height"
                " are not given in json dataset when calling extended_coco_load,"
                " if you encounter issue, consider disable ASPECT_RATIO_GROUPING."
            )

        # Per-mapper error statistics, updated by __call__.
        self._error_count = 0
        self._total_counts = 0
        self._error_types = {}

    def _original_call(self, dataset_dict):
        """
        Modified from detectron2's original __call__ in DatasetMapper

        Args:
            dataset_dict (dict): metadata of one image; it is deep-copied, the
                caller's dict is not modified.

        Returns:
            dict: the copy with "image" set to a float32 CHW tensor and, in
            training, "instances" / "sem_seg" added when the corresponding
            inputs are present.
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below

        image = self._read_image(dataset_dict, format=self.img_format)
        if not self.backfill_size:
            # With backfill enabled width/height were just copied from the image,
            # so there is nothing to cross-check.
            utils.check_image_size(dataset_dict, image)

        image, dataset_dict = self._custom_transform(image, dataset_dict)

        inputs = AugInput(image=image)
        if "annotations" not in dataset_dict:
            crop_gens = [self.crop_gen] if self.crop_gen else []
            transforms = AugmentationList(crop_gens + self.tfm_gens)(inputs)
            image = inputs.image
        else:
            # pass additional arguments, will only be used when the Augmentation
            #   takes `annotations` as input
            inputs.annotations = dataset_dict["annotations"]
            # Crop around an instance if there are instances in the image.
            if self.crop_gen:
                # NOTE(review): np.random.choice raises ValueError on an empty
                # list, so `annotations == []` fails here when cropping is on.
                crop_tfm = utils.gen_crop_transform_with_instance(
                    self.crop_gen.get_crop_size(image.shape[:2]),
                    image.shape[:2],
                    np.random.choice(dataset_dict["annotations"]),
                )
                # The crop must be stored on `inputs`: the augmentations below
                # operate on `inputs.image`, and `crop_tfm` is also applied to
                # the annotations, so image and labels have to see the same crop.
                inputs.image = crop_tfm.apply_image(image)
            transforms = AugmentationList(self.tfm_gens)(inputs)
            image = inputs.image
            if self.crop_gen:
                transforms = crop_tfm + transforms

        image_shape = image.shape[:2]  # h, w
        if image.ndim == 2:
            image = np.expand_dims(image, 2)
        # HWC -> CHW
        dataset_dict["image"] = torch.as_tensor(
            image.transpose(2, 0, 1).astype("float32")
        )
        # Can use uint8 if it turns out to be slow some day

        if self.load_proposals:
            utils.transform_proposals(
                dataset_dict,
                image_shape,
                transforms,
                proposal_topk=self.proposal_topk,
                min_box_size=self.proposal_min_box_size,
            )

        if not self.is_train:
            # Ground truth is not returned outside of training.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # Drop annotation fields the model is not configured to use.
            for anno in dataset_dict["annotations"]:
                if not self.mask_on:
                    anno.pop("segmentation", None)
                if not self.keypoint_on:
                    anno.pop("keypoints", None)

            # Crowd annotations are skipped.
            annos = [
                utils.transform_instance_annotations(
                    obj,
                    transforms,
                    image_shape,
                    keypoint_hflip_indices=self.keypoint_hflip_indices,
                )
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = utils.annotations_to_instances(
                annos, image_shape, mask_format=self.mask_format
            )
            # Create a tight bounding box from masks, useful when image is cropped
            if self.crop_gen and instances.has("gt_masks"):
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)

        if "sem_seg_file_name" in dataset_dict:
            sem_seg_gt = read_image_with_prefetch(
                dataset_dict.pop("sem_seg_file_name"),
                "L",
                prefetched=dataset_dict.get(PREFETCHED_SEM_SEG_FILE_NAME, None),
            ).squeeze(2)
            sem_seg_gt = transforms.apply_segmentation(sem_seg_gt)
            sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))
            dataset_dict["sem_seg"] = sem_seg_gt

        # extend standard D2 semantic segmentation to support multiple segmentation
        # files, each file can represent a class
        if "multi_sem_seg_file_names" in dataset_dict:
            raise NotImplementedError()

        if "_post_process_" in dataset_dict:
            # Optional hook stored in the dict itself; it receives the mapped
            # dict and its return value replaces it.
            proc_func = dataset_dict.pop("_post_process_")
            dataset_dict = proc_func(dataset_dict)

        return dataset_dict

    def __call__(self, dataset_dict):
        """
        Map one dataset dict, with retry and optional exception handling.

        Args:
            dataset_dict (dict): metadata of one image.

        Returns:
            dict or None: the mapped dict, or None when mapping raised and
            `self.catch_exception` is set.
        """
        self._total_counts += 1

        # presumably `retryable` re-invokes `_f` on failure up to `self.retry`
        # times -- confirm against d2go.utils.helper.
        @retryable(num_tries=self.retry, sleep_time=0.1)
        def _f():
            return self._original_call(dataset_dict)

        if not self.catch_exception:
            return _f()

        try:
            return _f()
        except Exception as e:
            self._error_count += 1
            error_type = type(e).__name__
            self._error_types[error_type] = self._error_types.get(error_type, 0) + 1

            # Only report every 100th error to keep the log readable.
            if self._error_count % 100 == 0:
                logger.warning(
                    "{}Error when applying transform for dataset_dict: {};"
                    " error rate {}/{} ({:.2f}%), msg: {}".format(
                        self._get_logging_prefix(),
                        dataset_dict,
                        self._error_count,
                        self._total_counts,
                        100.0 * self._error_count / self._total_counts,
                        repr(e),
                    )
                )
                self._log_error_type_stats()

            # NOTE: the contract with MapDataset allows return `None` such that
            # it'll randomly use other element in the dataset. We use this
            # feature to handle error.
            return None

    def _get_logging_prefix(self):
        """
        Return a "[worker: id/num_workers] " prefix, or "" when
        `torch.utils.data.get_worker_info()` reports no worker.
        """
        worker_info = torch.utils.data.get_worker_info()
        if not worker_info:
            return ""

        prefix = "[worker: {}/{}] ".format(worker_info.id, worker_info.num_workers)
        return prefix

    def _log_error_type_stats(self):
        """
        Log, per exception type name, how many errors were counted relative to
        the total number of calls.
        """
        error_type_count_msgs = [
            "{}: {}/{} ({:.2f}%)".format(
                k, v, self._total_counts, 100.0 * v / self._total_counts
            )
            for k, v in self._error_types.items()
        ]
        logger.warning(
            "{}Error statistics:\n{}".format(
                self._get_logging_prefix(), "\n".join(error_type_count_msgs)
            )
        )

    def _read_image(self, dataset_dict, format=None):
        """
        Load the image described by `dataset_dict` as a numpy array.

        Uses `self.image_loader` when it is set and reports support for this
        dict, otherwise falls back to `read_image_with_prefetch`. When
        `self.backfill_size` is set, "width" and "height" of `dataset_dict`
        are overwritten in place with the loaded image's size.
        """
        if not (self.image_loader and self.image_loader.support(dataset_dict)):
            # fallback to use D2's read_image
            image = read_image_with_prefetch(
                dataset_dict["file_name"],
                format=format,
                prefetched=dataset_dict.get(PREFETCHED_FILE_NAME),
            )
            if self.backfill_size:
                # Slice instead of 3-way unpacking so an array without a
                # channel axis does not raise.
                h, w = image.shape[:2]
                dataset_dict["width"] = w
                dataset_dict["height"] = h
            return image

        image = self.image_loader(dataset_dict)
        if self.backfill_size:
            dataset_dict["width"] = image.width
            dataset_dict["height"] = image.height

        return utils.convert_PIL_to_numpy(image, format)

    def _custom_transform(self, image, dataset_dict):
        """
        Override this method to inject custom transform.

        Called right after the image is read and before any augmentation;
        the default implementation returns its inputs unchanged.
        """
        return image, dataset_dict

    def __repr__(self):
        """Return a multi-line summary of mode, image loader and augmentations."""
        tfm_lines = "\n".join(["    - {}".format(x) for x in self.tfm_gens])
        return (
            self.__class__.__name__
            + ":\n"
            + "\n".join(
                [
                    "  is_train: {}".format(self.is_train),
                    "  image_loader: {}".format(self.image_loader),
                    "  tfm_gens: \n{}".format(tfm_lines),
                ]
            )
        )
